Amazon SageMakerで作成したFactorization Machinesの分類モデルをローカルで推論させる
SageMakerには組み込みアルゴリズムの1つとしてFactorization Machines(FM)が用意されています。
今回はSageMakerで学習させたFMモデルを使ったローカルでの推論を試してみたので、その内容を紹介します。
やってみる
今回の焦点はローカルでのFMモデルの推論ですので、SageMakerでのモデルの学習部分は割愛します。 データセットはMNISTを使用し、モデルは以下のノートブックに則ってSageMakerで作成したFMの二値分類モデルを使用します。
事前準備
ライブラリを読み込み、モデルアーティファクトを保存している場所を定義しておきます。
import boto3 import json import pickle import gzip import urllib.request import mxnet as mx import numpy as np from os import path session = boto3.Session() bucket = 'hogehoge' model_path = 'sagemaker/DEMO-fm-mnist/output/factorization-machines-2019-08-15-09-38-15-958/output/model.tar.gz' model_file_name = path.basename(model_path)
データ準備
MNISTの手書き数字のデータセットをダウンロードし、展開します。 データセットには学習と検証、テスト用がありますが、今回は検証用を使って進めます。 元々は0~9の数字のどれかを予測する分類問題ですが、今回のモデルは0かどうかの二値分類のモデルなので、ラベルを書き換えます。
urllib.request.urlretrieve("http://deeplearning.net/data/mnist/mnist.pkl.gz", "mnist.pkl.gz") with gzip.open('mnist.pkl.gz', 'rb') as f: train_set, valid_set, test_set = pickle.load(f, encoding='latin1') vectors, labels = valid_set # 0かどうかの二値分類にする labels = (labels == 0).astype('float32')
モデルのダウンロードと展開
S3にあるモデルアーティファクトをダウンロード&展開し、MXNetで読み込める形に名前を変更します。
# FMモデルをダウンロードして展開 session.client('s3').download_file(bucket, model_path, model_file_name) os.system(f'tar xzvf {model_file_name}') os.system(f'unzip -o model_algo-1') os.system(f'mv symbol.json model-symbol.json') os.system(f'mv params model-0000.params')
モデルの読み込み
モデルのメタ情報がjsonファイルで提供されているので、読み込んでみます。
with open('meta.json', 'r') as f: meta = json.load(f) print(meta)
出力
{'label_names': ['out_label'], 'training_parameters': {'factors_lr': '0.0001', 'linear_init_sigma': '0.01', 'epochs': 1, 'feature_dim': '784', 'num_factors': '10', '_wd': '1.0', '_num_kv_servers': 'auto', 'use_bias': 'true', 'factors_init_sigma': '0.001', '_log_level': 'info', 'bias_init_method': 'normal', 'linear_init_method': 'normal', 'linear_lr': '0.001', 'factors_init_method': 'normal', '_tuning_objective_metric': '', 'bias_wd': '0.01', 'use_linear': 'true', 'bias_lr': '0.1', 'mini_batch_size': '200', '_use_full_symbolic': 'true', 'batch_metrics_publish_interval': '500', 'predictor_type': 'binary_classifier', 'bias_init_sigma': '0.01', '_num_gpus': 'auto', '_data_format': 'record', 'factors_wd': '0.00001', 'linear_wd': '0.001', '_kvstore': 'auto', '_learning_rate': '1.0', '_optimizer': 'adam'}, 'version': 1, 'epoch_number': 1}
学習時のパラメータなどを確認することができます。
次に、MXNetでモデルを読み込みます。
m = mx.module.Module.load('./model', epoch=0, label_names=meta['label_names'])
データをMXNetで扱いやすい形に変換し、そのデータ形式をモデルに設定します。
# MXNetの対応するデータに変換 validation_iter = mx.io.NDArrayIter(vectors, labels, label_name=meta['label_names'][0]) # データ形式を設定 m.bind(data_shapes=validation_iter.provide_data, label_shapes=validation_iter.provide_label)
推論
検証用のデータを推論し、推論結果をndarray形式に変換したものを表示します。
pred = m.predict(eval_data=validation_iter) print(pred.asnumpy())
出力
[[0.] [0.] [0.] ... [0.] [0.] [0.]]
※ 0が並んでいますが、出力は0から1の実数です。
さいごに
SageMakerで学習させたFMモデルを使ってローカルで推論する方法について紹介しました。誰かの役に立てば嬉しいです。